"""Jitter and phase utilities for Simulation B.

This module provides functions to sample a random relative phase offset
Δφ from a Gaussian or uniform distribution, wrap it to the interval
(-π, π], quantise it into phase bins, and classify screen columns
according to phase windows.  The driver uses these functions to emulate
a pure-dephasing instrument.
"""

from __future__ import annotations

import math
from typing import Callable, Tuple

import numpy as np

def sample_gaussian(sigma: float, rng: np.random.Generator) -> float:
    """Return a zero-mean Gaussian phase offset (radians), wrapped by ``wrap_to_pi``.

    ``sigma`` is the standard deviation in radians.  When ``sigma`` is
    exactly zero the function returns 0.0 immediately and does not draw
    from ``rng``, so the generator state is left untouched.
    """
    if sigma != 0.0:
        return wrap_to_pi(rng.normal(0.0, sigma))
    # Zero width: no jitter, and no random number consumed.
    return 0.0


def sample_uniform(a: float, rng: np.random.Generator) -> float:
    """Return a phase offset drawn uniformly from [-a, a], wrapped by ``wrap_to_pi``.

    ``a`` is the half-width in radians (presumably non-negative; this is not
    checked here).  When ``a`` is exactly zero the function returns 0.0
    without drawing from ``rng``.
    """
    if a != 0.0:
        return wrap_to_pi(rng.uniform(-a, a))
    # Zero half-width: no jitter, and no random number consumed.
    return 0.0


def wrap_to_pi(phi: float) -> float:
    """Wrap an angle in radians to the half-open interval (-π, π].

    Both +π and -π map to +π, so the lower endpoint -π is never returned.
    The argument is expected to be a scalar float; the boundary fix-up
    below uses a scalar comparison.
    """
    two_pi = 2.0 * math.pi
    # (π - phi) % 2π lies in [0, 2π), hence π minus it lies in (-π, π].
    wrapped = math.pi - (math.pi - phi) % two_pi
    # Float rounding can make the modulo return exactly 2π for inputs a
    # hair above π, which would produce -π; fold that back onto +π.
    if wrapped <= -math.pi:
        wrapped += two_pi
    return wrapped


def quantise_phase(delta_phi: float, bin_width: float) -> int:
    """Convert a phase offset into a signed whole number of bins.

    ``bin_width`` is the bin size in radians (2π / B for B bins).  The
    quotient ``delta_phi / bin_width`` is rounded with Python's built-in
    ``round`` (exact halves go to the even integer).  The result may be
    negative or exceed B-1; ``wrap_bins`` maps it into [0, B-1].
    """
    nearest = round(delta_phi / bin_width)
    return int(nearest)


def phi_geom(x: int, x0: int, period_px: int) -> float:
    """Geometric phase at screen column x.

    Computes phi_geom(x) = 2π (x - x0) / P, where P is ``period_px``.  The
    phase is zero at column ``x0`` and grows by 2π per ``period_px``
    columns.  It is not wrapped.  A ``period_px`` of 0 raises
    ZeroDivisionError.
    """
    return 2.0 * math.pi * (x - x0) / period_px


def phase_bin(phi: float, B: int) -> int:
    """Map a geometric phase in radians to a bin index in [0, B-1].

    The phase is reduced modulo 2π and split into ``B`` equal bins of width
    2π / B, with bin 0 starting at phase 0.
    """
    # Reduce phi into [0, 2π) -- nominally; see the fold-back below.
    phi_norm = phi % (2.0 * math.pi)
    bin_width = 2.0 * math.pi / B
    # Float % can return exactly 2π (e.g. for a tiny negative phi), and the
    # division can round up to B for phases just below 2π.  The final % B
    # folds that case onto bin 0 so the index always stays in [0, B-1].
    return math.floor(phi_norm / bin_width) % B


def wrap_bins(bin_idx: int, B: int) -> int:
    """Reduce an arbitrary (possibly negative) bin index modulo ``B``.

    For positive ``B`` the result lies in [0, B-1].
    """
    wrapped = bin_idx % B
    return wrapped


def in_window(bin_idx: int, center: int, eps_bins: int, B: int) -> bool:
    """Test whether ``bin_idx`` is within ``eps_bins`` of ``center`` on a ring of ``B`` bins.

    The distance is circular: the shorter of the two arcs from ``center``
    to ``bin_idx``.  The window is inclusive, so a distance equal to
    ``eps_bins`` counts as inside.  Both indices are expected in [0, B-1].
    """
    forward = (bin_idx - center) % B
    # Going the other way round the ring is congruent to -forward modulo B.
    backward = (-forward) % B
    return min(forward, backward) <= eps_bins


def jitter_sampler_factory(law: str, param: float) -> Callable[[np.random.Generator], float]:
    """Build a one-argument sampler for the named jitter law.

    ``law`` selects the distribution:

    * ``'gaussian'`` -- ``param`` is the standard deviation sigma (radians),
      and draws come from ``sample_gaussian``.
    * ``'uniform'`` -- ``param`` is the half-width ``a`` (radians), and draws
      come from ``sample_uniform``.

    The returned callable takes a ``numpy.random.Generator`` and yields a
    wrapped phase offset.  Any other ``law`` raises ValueError.
    """
    if law not in ("gaussian", "uniform"):
        raise ValueError(f"Unknown jitter law: {law}")
    if law == "gaussian":
        return lambda rng: sample_gaussian(param, rng)
    return lambda rng: sample_uniform(param, rng)
